{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython import display\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "from utils import Logger\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_v2_behavior() \n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder used by mnist_data() for the dataset download\n",
    "DATA_FOLDER = './tf_data/VGAN/MNIST'\n",
    "# Number of values in one flattened 28x28 image\n",
    "IMAGE_PIXELS = 28 * 28\n",
    "# Width of the noise vectors fed to the generator placeholder\n",
    "NOISE_SIZE = 100\n",
    "# Mini-batch size used by the DataLoader and for the noise batches\n",
    "BATCH_SIZE = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def noise(n_rows, n_cols):\n",
    "    \"\"\"Draw an (n_rows, n_cols) array of standard-normal samples from NumPy's global RNG.\"\"\"\n",
    "    shape = (n_rows, n_cols)\n",
    "    return np.random.normal(loc=0.0, scale=1.0, size=shape)\n",
    "\n",
    "def xavier_init(size):\n",
    "    \"\"\"Return a uniform random tensor of shape `size` bounded by 1/sqrt(fan_in).\n",
    "\n",
    "    Weights in this notebook are applied as `tf.matmul(x, W)`, so a weight shape\n",
    "    is [fan_in, fan_out] and the fan-in is the FIRST entry. (Reading it from\n",
    "    size[1] gave, e.g., a bound of 1.0 for the [256, 1] output layer.)\n",
    "    For a 1-D (bias) shape the only entry is used.\n",
    "\n",
    "    Args:\n",
    "        size: list/tuple of ints, either [fan_in, fan_out] or [n].\n",
    "\n",
    "    Returns:\n",
    "        Tensor sampled from U(-1/sqrt(size[0]), 1/sqrt(size[0])).\n",
    "    \"\"\"\n",
    "    fan_in = size[0]\n",
    "    # This is the half-width of a uniform distribution, not a standard deviation\n",
    "    bound = 1. / np.sqrt(float(fan_in))\n",
    "    return tf.random_uniform(shape=size, minval=-bound, maxval=bound)\n",
    "\n",
    "def images_to_vectors(images):\n",
    "    \"\"\"Flatten a batch of images into rows of 784 values, one row per image.\"\"\"\n",
    "    batch_size = images.shape[0]\n",
    "    return images.reshape(batch_size, 784)\n",
    "\n",
    "def vectors_to_images(vectors):\n",
    "    \"\"\"Reshape rows of 784 values into a (batch, 28, 28, 1) image batch.\"\"\"\n",
    "    batch_size = vectors.shape[0]\n",
    "    image_shape = (batch_size, 28, 28, 1)\n",
    "    return vectors.reshape(image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def mnist_data():\n",
    "    \"\"\"Return the MNIST training split (downloaded if missing) with ToTensor + Normalize(.5, .5) applied.\"\"\"\n",
    "    out_dir = '{}/dataset'.format(DATA_FOLDER)\n",
    "    pipeline = [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((.5,), (.5,)),\n",
    "    ]\n",
    "    compose = transforms.Compose(pipeline)\n",
    "    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the MNIST training dataset\n",
    "data = mnist_data()\n",
    "# Wrap it in a shuffling loader that yields mini-batches of BATCH_SIZE samples\n",
    "data_loader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)\n",
    "# Batches per epoch\n",
    "num_batches = len(data_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Discriminator\n",
    "\n",
    "# Input: flattened images\n",
    "X = tf.placeholder(tf.float32, shape=(None, IMAGE_PIXELS))\n",
    "\n",
    "# (inputs, outputs) of each dense layer: three hidden layers, then one output logit\n",
    "D_layer_dims = [(784, 1024), (1024, 512), (512, 256), (256, 1)]\n",
    "\n",
    "# Create weight then bias for each layer, in layer order\n",
    "D_var_list = []\n",
    "for n_in, n_out in D_layer_dims:\n",
    "    D_var_list.append(tf.Variable(xavier_init([n_in, n_out])))\n",
    "    D_var_list.append(tf.Variable(xavier_init([n_out])))\n",
    "\n",
    "# Expose the individual variables under the names used by discriminator()\n",
    "D_W1, D_B1, D_W2, D_B2, D_W3, D_B3, D_W4, D_B4 = D_var_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Generator\n",
    "\n",
    "# Input: noise vectors\n",
    "Z = tf.placeholder(tf.float32, shape=(None, NOISE_SIZE))\n",
    "\n",
    "# (inputs, outputs) of each dense layer: three hidden layers, then the 784-value output\n",
    "G_layer_dims = [(100, 256), (256, 512), (512, 1024), (1024, 784)]\n",
    "\n",
    "# Create weight then bias for each layer, in layer order\n",
    "G_var_list = []\n",
    "for n_in, n_out in G_layer_dims:\n",
    "    G_var_list.append(tf.Variable(xavier_init([n_in, n_out])))\n",
    "    G_var_list.append(tf.Variable(xavier_init([n_out])))\n",
    "\n",
    "# Expose the individual variables under the names used by generator()\n",
    "G_W1, G_B1, G_W2, G_B2, G_W3, G_B3, G_W4, G_B4 = G_var_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discriminator(x):\n",
    "    \"\"\"Map a batch of flattened images to one unnormalised logit per image.\n",
    "\n",
    "    Three leaky-ReLU (alpha=.2) dense layers, each followed by dropout, then a\n",
    "    linear output layer. Dropout is passed as `rate=.3` (drop 30% of the\n",
    "    activations): in `tf.compat.v1.nn.dropout` the second positional argument\n",
    "    is the deprecated `keep_prob`, so a positional `.3` kept only 30% of units.\n",
    "\n",
    "    Args:\n",
    "        x: float tensor whose last dimension matches D_W1's first (784).\n",
    "\n",
    "    Returns:\n",
    "        Tensor of logits with last dimension 1 (no sigmoid applied).\n",
    "    \"\"\"\n",
    "    l1 = tf.nn.dropout(tf.nn.leaky_relu(tf.matmul(x,   D_W1) + D_B1, .2), rate=.3)\n",
    "    l2 = tf.nn.dropout(tf.nn.leaky_relu(tf.matmul(l1,  D_W2) + D_B2, .2), rate=.3)\n",
    "    l3 = tf.nn.dropout(tf.nn.leaky_relu(tf.matmul(l2,  D_W3) + D_B3, .2), rate=.3)\n",
    "    out = tf.matmul(l3, D_W4) + D_B4\n",
    "    return out\n",
    "\n",
    "def generator(z):\n",
    "    \"\"\"Map a batch of noise vectors to flattened images through three leaky-ReLU layers and a tanh output.\"\"\"\n",
    "    hidden = z\n",
    "    # Hidden layers share the same form: leaky_relu(h @ W + b) with alpha=.2\n",
    "    for weights, bias in ((G_W1, G_B1), (G_W2, G_B2), (G_W3, G_B3)):\n",
    "        hidden = tf.nn.leaky_relu(tf.matmul(hidden, weights) + bias, .2)\n",
    "    return tf.nn.tanh(tf.matmul(hidden, G_W4) + G_B4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wire up the graph: generated images, then discriminator logits for real and generated batches\n",
    "G_sample = generator(Z)\n",
    "D_real = discriminator(X)\n",
    "D_fake = discriminator(G_sample)\n",
    "\n",
    "# Discriminator loss: real logits are scored against label 1, fake logits against label 0\n",
    "D_real_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real, labels=tf.ones_like(D_real))\n",
    "D_loss_real = tf.reduce_mean(D_real_xent)\n",
    "D_fake_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake, labels=tf.zeros_like(D_fake))\n",
    "D_loss_fake = tf.reduce_mean(D_fake_xent)\n",
    "D_loss = D_loss_real + D_loss_fake\n",
    "\n",
    "# Generator loss: fake logits are scored against label 1\n",
    "G_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake, labels=tf.ones_like(D_fake))\n",
    "G_loss = tf.reduce_mean(G_xent)\n",
    "\n",
    "# One Adam optimizer (learning rate 2e-4) per network, each updating only its own variables\n",
    "D_opt = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(D_loss, var_list=D_var_list)\n",
    "G_opt = tf.train.AdamOptimizer(learning_rate=2e-4).minimize(G_loss, var_list=G_var_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fixed Test Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed batch of noise vectors, reused for every progress snapshot during training\n",
    "num_test_samples = 16\n",
    "test_noise = noise(n_rows=num_test_samples, n_cols=NOISE_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of full passes over the training set\n",
    "num_epochs = 200\n",
    "\n",
    "# Start interactive session\n",
    "session = tf.InteractiveSession()\n",
    "# Initialise every graph variable in that session\n",
    "tf.global_variables_initializer().run()\n",
    "# Init Logger. The names match what this notebook trains (vanilla GAN on MNIST,\n",
    "# see DATA_FOLDER) so the logs are not filed under an unrelated DCGAN/CIFAR10 run.\n",
    "logger = Logger(model_name='VGAN', data_name='MNIST')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Alternate one discriminator update and one generator update per mini-batch\n",
    "for epoch in range(num_epochs):\n",
    "    for n_batch, (batch, _) in enumerate(data_loader):\n",
    "\n",
    "        # 1. Discriminator step on a real batch plus freshly sampled noise.\n",
    "        #    Axis 1 of the torch batch is moved last before flattening to vectors.\n",
    "        X_batch = images_to_vectors(batch.permute(0, 2, 3, 1).numpy())\n",
    "        feed_dict = {X: X_batch, Z: noise(BATCH_SIZE, NOISE_SIZE)}\n",
    "        d_fetches = [D_opt, D_loss, D_real, D_fake]\n",
    "        _, d_error, d_pred_real, d_pred_fake = session.run(d_fetches, feed_dict=feed_dict)\n",
    "\n",
    "        # 2. Generator step on a new noise batch\n",
    "        feed_dict = {Z: noise(BATCH_SIZE, NOISE_SIZE)}\n",
    "        _, g_error = session.run([G_opt, G_loss], feed_dict=feed_dict)\n",
    "\n",
    "        # Report progress only on every 100th batch\n",
    "        if n_batch % 100 != 0:\n",
    "            continue\n",
    "\n",
    "        display.clear_output(True)\n",
    "        # Render the fixed test noise through the current generator\n",
    "        test_images = vectors_to_images(session.run(G_sample, feed_dict={Z: test_noise}))\n",
    "        # Log Images\n",
    "        logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches, format='NHWC')\n",
    "        # Log Status\n",
    "        logger.display_status(\n",
    "            epoch, num_epochs, n_batch, num_batches,\n",
    "            d_error, g_error, d_pred_real, d_pred_fake\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
